import numpy as np
import paddle
from paddle_quantum.circuit import UAnsatz


def U_theta(n, layer, depth, theta, x):
    """Build a circuit of alternating trainable and data-encoding blocks.

    The circuit is ``layer`` repetitions of (trainable block, encoding block)
    followed by one closing trainable block that uses ``theta[-1]``.

    Args:
        n: Number of qubits passed to ``UAnsatz``. Trainable gates are only
            added when ``n`` is 1 or 2; for any other value the circuit
            contains the ``rz`` encoding gates alone.
        layer: Number of (trainable, encoding) repetitions.
        depth: Depth forwarded to ``complex_entangled_layer`` when ``n == 2``.
            Must equal 1 when ``n == 1``.
        theta: Trainable parameters, indexed first by block. For ``n == 1``
            the three ``u3`` angles are read from ``theta[block][0][0]``.
        x: Encoding angles for one sample. They are consumed one per ``rz``
            gate and wrap around to the start when exhausted.

    Returns:
        The assembled ``UAnsatz`` circuit (not yet executed).
    """
    cursor, num_angles = 0, len(x)
    circuit = UAnsatz(n)

    def add_trainable_block(block_index):
        # theta is indexed only inside the branches so that an unsupported
        # qubit count never touches the parameters at all.
        if n == 1:
            assert depth == 1
            circuit.u3(*theta[block_index][0][0], which_qubit=0)
        # Extension note: a universal 2-qubit gate could replace the entangled
        # layer here, but the parameter initialization would need to change.
        if n == 2:
            circuit.complex_entangled_layer(theta[block_index], depth=depth)

    for block_index in range(layer):
        # W(theta): trainable block
        add_trainable_block(block_index)

        # S(x): one rz per qubit, cycling through the sample's angles
        for qubit in range(n):
            circuit.rz(x[cursor], qubit)
            cursor = (cursor + 1) % num_angles

    # Closing W(theta): the (L+1)-th trainable block
    add_trainable_block(-1)

    return circuit


class QNN(paddle.nn.Layer):
    """Trainable layer that evaluates the ``U_theta`` circuit per sample.

    Holds a single parameter tensor ``theta`` of shape
    ``[layer + 1, depth, n, 3]``, initialized uniformly in ``[0, 2*pi)``.
    """

    def __init__(self, n, layer, depth):
        """Store the circuit dimensions and create the trainable parameters.

        Args:
            n: Number of qubits.
            layer: Number of layers L; ``layer + 1`` parameter blocks are created.
            depth: Depth of each trainable block.
        """
        super().__init__()
        self.num_qubits = n
        self.layer = layer
        self.depth = depth

        # Extension note: switching to a universal 2-qubit gate would need
        # 15 parameters per block, i.e. shape [layer + 1, 15].
        angle_init = paddle.nn.initializer.Uniform(0.0, 2 * np.pi)
        self.theta = self.create_parameter(
            shape=[layer + 1, depth, n, 3],
            default_initializer=angle_init,
            dtype='float64',
            is_bias=False,
        )

    def forward(self, x):
        """Run the circuit once per row of ``x`` and collect expectation values.

        Args:
            x: Batch of samples; ``x.shape[0]`` rows are processed and each
                row ``x[i]`` is passed to ``U_theta`` as the encoding angles.

        Returns:
            A tuple ``(predictions, cir)``: the concatenated expectation
            values flattened to 1-D, and the circuit built for the last row.
        """
        # Observable description: one term with coefficient 1.0 and label
        # 'z<i>' for every qubit index i.
        observable = [[1.0, f'z{qubit}'] for qubit in range(self.num_qubits)]

        outputs = []
        for row in range(x.shape[0]):
            cir = U_theta(self.num_qubits, self.layer, self.depth, self.theta, x[row])
            cir.run_state_vector()
            outputs.append(cir.expecval(observable))

        flat_outputs = paddle.concat(outputs).reshape((-1,))
        return flat_outputs, cir


def train_model(train_X, train_y, seed, N, LAYER, DEPTH, EPOCH=10, BATCH_SIZE=40, LR=0.1):
    """Train a ``QNN`` with Adam on a mean-squared-error loss.

    The data is shuffled once, then split into disjoint mini-batches of
    ``BATCH_SIZE`` samples. Trailing samples that do not fill a whole batch
    are not used for training.

    Args:
        train_X: Training inputs. Must expose ``.shape[0]`` and accept an
            integer index array (``train_X[idx]``), e.g. a NumPy array.
        train_y: Training targets, indexable the same way as ``train_X``.
        seed: Seed passed to ``paddle.seed`` (affects parameter initialization).
        N: Number of qubits.
        LAYER: Number of layers L of the network.
        DEPTH: Depth of each trainable block.
        EPOCH: Number of passes over the shuffled data.
        BATCH_SIZE: Number of samples per mini-batch.
        LR: Adam learning rate.

    Returns:
        A tuple ``(train_loss, predict_y)``: the list of per-step loss values
        (each the result of ``avg_loss.numpy()``), and the trained network's
        predictions on the full, unshuffled ``train_X``.
    """
    paddle.seed(seed)
    net = QNN(N, LAYER, DEPTH)

    opt = paddle.optimizer.Adam(learning_rate=LR, parameters=net.parameters())

    train_loss = []

    # NOTE: the shuffle order is fixed by reseeding NumPy's global RNG with 0,
    # independently of ``seed``.
    idx = np.arange(train_X.shape[0])
    np.random.seed(0)
    np.random.shuffle(idx)
    train_X_random = paddle.to_tensor(train_X[idx], dtype="float64")
    train_y_random = paddle.to_tensor(train_y[idx], dtype="float64")

    num_batches = train_X_random.shape[0] // BATCH_SIZE

    for epoch in range(EPOCH):
        for j in range(num_batches):
            # Each batch is the j-th disjoint window of BATCH_SIZE samples.
            # (The start index must be j * BATCH_SIZE, not j, otherwise the
            # batches overlap and grow with j.)
            start = j * BATCH_SIZE
            stop = start + BATCH_SIZE
            batch_X = train_X_random[start:stop]
            batch_y = train_y_random[start:stop]
            predict, cir = net(batch_X)

            # Show the circuit layout once, at the very first step.
            if epoch==0 and j==0:
                print(cir)

            avg_loss = paddle.mean((predict - batch_y) ** 2)
            loss_value = avg_loss.numpy()
            train_loss.append(loss_value)

            print("Epoch:%s ----- batch:%s ----- training loss %s"%(epoch, j, loss_value[0]))

            avg_loss.backward()
            opt.minimize(avg_loss)
            opt.clear_grad()

    # Final evaluation on the whole training set in its original order.
    train_X = paddle.to_tensor(train_X, dtype="float64")
    predict_y, cir = net(train_X)
    print(cir)

    return train_loss, predict_y